{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GLM4大模型微调入门-命名实体识别任务\n",
    "\n",
    "作者：林泽毅"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.安装环境\n",
    "\n",
    "本案例测试于modelscope==1.14.0、transformers==4.41.2、datasets==2.18.0、peft==0.11.1、accelerate==0.30.1、swanlab==0.3.11、tiktoken==0.7.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin the packages to the versions this notebook was tested with; torch and pandas have no stated tested version and stay unpinned\n",
    "%pip install torch pandas modelscope==1.14.0 transformers==4.41.2 datasets==2.18.0 peft==0.11.1 accelerate==0.30.1 swanlab==0.3.11 tiktoken==0.7.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果是第一次使用SwanLab，请先前往[SwanLab](https://swanlab.cn)注册账号，在[用户设置](https://swanlab.cn/settings/overview)复制API Key，然后执行下面的代码并按提示粘贴API Key完成登录："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!swanlab login"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 数据集加载\n",
    "\n",
    "1. 在[chinese_ner_sft - huggingface](https://huggingface.co/datasets/qgyd2021/chinese_ner_sft/tree/main/data)下载ccfbdci.jsonl到同级目录下。\n",
    "\n",
    "<img src=\"../assets/ner_dataset.png\" width=600>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. 将ccfbdci.jsonl处理成微调所需的格式并保存为ccf_train.jsonl，然后按9:1划分出训练集和测试集（测试集再从中随机抽取20条）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert ccfbdci.jsonl into the fine-tuning format (ccf_train.jsonl), then split it into train and test sets\n",
    "\n",
    "import json\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "def dataset_jsonl_transfer(origin_path, new_path):\n",
    "    \"\"\"\n",
    "    Convert the raw NER JSONL file into instruction-tuning records.\n",
    "\n",
    "    Every input line is parsed as a JSON object with a \"text\" field and an\n",
    "    \"entities\" list whose items carry \"entity_text\" and \"entity_names\".\n",
    "    Every output line is a JSON object with \"instruction\", \"input\" and\n",
    "    \"output\" fields.\n",
    "\n",
    "    Args:\n",
    "        origin_path: Path of the raw JSONL file, read as UTF-8.\n",
    "        new_path: Path of the converted JSONL file, written as UTF-8.\n",
    "    \"\"\"\n",
    "    # Loop-invariant values, built once instead of once per input line\n",
    "    match_names = [\"地点\", \"人名\", \"地理实体\", \"组织\"]\n",
    "    instruction = \"\"\"你是一个文本实体识别领域的专家，你需要从给定的句子中提取 地点; 人名; 地理实体; 组织 实体. 以 json 格式输出, 如 {\"entity_text\": \"南京\", \"entity_label\": \"地理实体\"} 注意: 1. 输出的每一行都必须是正确的 json 字符串. 2. 找不到任何实体时, 输出\"没有找到任何实体\". \"\"\"\n",
    "    messages = []\n",
    "\n",
    "    # Decode explicitly as UTF-8 instead of relying on the platform default encoding\n",
    "    with open(origin_path, \"r\", encoding=\"utf-8\") as file:\n",
    "        for line in file:\n",
    "            if not line.strip():\n",
    "                continue  # skip blank lines rather than failing in json.loads\n",
    "            data = json.loads(line)\n",
    "            input_text = data[\"text\"]\n",
    "\n",
    "            entity_parts = []\n",
    "            for entity in data[\"entities\"]:\n",
    "                # First alias that belongs to the target label set, or None\n",
    "                entity_label = next(\n",
    "                    (name for name in entity[\"entity_names\"] if name in match_names),\n",
    "                    None,\n",
    "                )\n",
    "                if entity_label is None:\n",
    "                    # No alias is a target label: skip the entity instead of\n",
    "                    # reusing the label left over from a previous entity\n",
    "                    continue\n",
    "                # json.dumps escapes quotes and backslashes, so each object is valid JSON\n",
    "                entity_parts.append(\n",
    "                    json.dumps(\n",
    "                        {\"entity_text\": entity[\"entity_text\"], \"entity_label\": entity_label},\n",
    "                        ensure_ascii=False,\n",
    "                    )\n",
    "                )\n",
    "\n",
    "            # The objects are concatenated without any separator\n",
    "            entity_sentence = \"\".join(entity_parts) or \"没有找到任何实体\"\n",
    "\n",
    "            messages.append(\n",
    "                {\n",
    "                    \"instruction\": instruction,\n",
    "                    \"input\": f\"文本:{input_text}\",\n",
    "                    \"output\": entity_sentence,\n",
    "                }\n",
    "            )\n",
    "\n",
    "    # Write one converted record per line\n",
    "    with open(new_path, \"w\", encoding=\"utf-8\") as file:\n",
    "        for message in messages:\n",
    "            file.write(json.dumps(message, ensure_ascii=False) + \"\\n\")\n",
    "\n",
    "\n",
    "# Convert the raw dataset once; later runs reuse the converted file\n",
    "train_dataset_path = \"ccfbdci.jsonl\"\n",
    "train_jsonl_new_path = \"ccf_train.jsonl\"\n",
    "\n",
    "if not os.path.exists(train_jsonl_new_path):\n",
    "    dataset_jsonl_transfer(train_dataset_path, train_jsonl_new_path)\n",
    "\n",
    "total_df = pd.read_json(train_jsonl_new_path, lines=True)\n",
    "split_index = int(len(total_df) * 0.1)\n",
    "train_df = total_df[split_index:]  # last 90% of the rows: training set\n",
    "held_out_df = total_df[:split_index]  # first 10% of the rows: never used for training\n",
    "# Sample up to 20 held-out rows for the prediction demo: the fixed seed makes the\n",
    "# selection repeatable, and min() avoids a ValueError when fewer than 20 rows exist\n",
    "test_df = held_out_df.sample(n=min(20, len(held_out_df)), random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 下载/加载模型和tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modelscope import snapshot_download, AutoTokenizer\n",
    "from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq\n",
    "import torch\n",
    "\n",
    "model_id = \"ZhipuAI/glm-4-9b-chat\"\n",
    "\n",
    "# Download GLM4 from ModelScope below the current directory; the returned path is what gets loaded\n",
    "model_dir = snapshot_download(model_id, cache_dir=\"./\", revision=\"master\")\n",
    "\n",
    "# Load the tokenizer and the model weights with Transformers\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_dir, device_map=\"auto\", torch_dtype=torch.bfloat16, trust_remote_code=True)\n",
    "model.enable_input_require_grads()  # must be called when gradient checkpointing is enabled"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 预处理训练数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_func(example, max_length=384):\n",
    "    \"\"\"\n",
    "    Tokenize one record into the fields used for causal-LM fine-tuning.\n",
    "\n",
    "    The prompt tokens are masked in the labels with -100; the response tokens\n",
    "    and one closing pad token are kept as labels.\n",
    "\n",
    "    Args:\n",
    "        example: Mapping with \"input\" and \"output\" strings and, optionally,\n",
    "            an \"instruction\" string used as the system prompt.\n",
    "        max_length: Maximum number of tokens kept; longer sequences are cut\n",
    "            on the right.\n",
    "\n",
    "    Returns:\n",
    "        dict with \"input_ids\", \"attention_mask\" and \"labels\" lists.\n",
    "    \"\"\"\n",
    "    default_system_prompt = \"\"\"你是一个文本实体识别领域的专家，你需要从给定的句子中提取 地点; 人名; 地理实体; 组织 实体. 以 json 格式输出, 如 {\"entity_text\": \"南京\", \"entity_label\": \"地理实体\"} 注意: 1. 输出的每一行都必须是正确的 json 字符串. 2. 找不到任何实体时, 输出\"没有找到任何实体\".\"\"\"\n",
    "    # Prefer the instruction stored with the record: the prediction loop sends that same\n",
    "    # column as the system message, so training and inference see identical text\n",
    "    if \"instruction\" in example and example[\"instruction\"]:\n",
    "        system_prompt = example[\"instruction\"]\n",
    "    else:\n",
    "        system_prompt = default_system_prompt\n",
    "\n",
    "    # NOTE(review): this hand-written turn layout may differ from what\n",
    "    # tokenizer.apply_chat_template produces in predict() - confirm they match\n",
    "    instruction = tokenizer(\n",
    "        f\"<|system|>\\n{system_prompt}<|endoftext|>\\n<|user|>\\n{example['input']}<|endoftext|>\\n<|assistant|>\\n\",\n",
    "        add_special_tokens=False,\n",
    "    )\n",
    "    response = tokenizer(f\"{example['output']}\", add_special_tokens=False)\n",
    "\n",
    "    # Prompt + response + one pad token that closes the sequence\n",
    "    input_ids = instruction[\"input_ids\"] + response[\"input_ids\"] + [tokenizer.pad_token_id]\n",
    "    attention_mask = instruction[\"attention_mask\"] + response[\"attention_mask\"] + [1]\n",
    "    labels = [-100] * len(instruction[\"input_ids\"]) + response[\"input_ids\"] + [tokenizer.pad_token_id]\n",
    "\n",
    "    # Slicing leaves shorter sequences untouched and truncates longer ones\n",
    "    return {\n",
    "        \"input_ids\": input_ids[:max_length],\n",
    "        \"attention_mask\": attention_mask[:max_length],\n",
    "        \"labels\": labels[:max_length],\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset\n",
    "\n",
    "# Wrap the training split in a Dataset, then tokenize it with process_func\n",
    "# and drop every original column from the result\n",
    "train_ds = Dataset.from_pandas(train_df)\n",
    "train_dataset = train_ds.map(\n",
    "    process_func,\n",
    "    remove_columns=train_ds.column_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. 设置LORA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "\n",
    "config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    "    # Attention and MLP projection layers that receive LoRA adapters.\n",
    "    # NOTE(review): \"activation_func\" was removed from this list; it looks like an\n",
    "    # activation rather than a projection layer, so it should never have matched\n",
    "    # an adaptable module - confirm against the GLM4 modeling code\n",
    "    target_modules=[\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"],\n",
    "    inference_mode=False,  # training mode\n",
    "    r=8,  # LoRA rank\n",
    "    lora_alpha=32,  # LoRA alpha (scaling factor)\n",
    "    lora_dropout=0.1,  # dropout ratio applied in the LoRA layers\n",
    ")\n",
    "\n",
    "model = get_peft_model(model, config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    # Output and checkpointing\n",
    "    output_dir=\"./output/GLM4-NER\",\n",
    "    save_steps=100,\n",
    "    save_on_each_node=True,\n",
    "    # Batch sizes and gradient accumulation\n",
    "    per_device_train_batch_size=4,\n",
    "    per_device_eval_batch_size=4,\n",
    "    gradient_accumulation_steps=4,\n",
    "    gradient_checkpointing=True,\n",
    "    # Optimisation schedule\n",
    "    num_train_epochs=2,\n",
    "    learning_rate=1e-4,\n",
    "    # Logging: built-in reporters are off, a callback is passed to the Trainer instead\n",
    "    logging_steps=10,\n",
    "    report_to=\"none\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from swanlab.integration.huggingface import SwanLabCallback\n",
    "import swanlab\n",
    "\n",
    "# Callback handed to the Trainer; the config dict records which model and dataset this run uses\n",
    "swanlab_callback = SwanLabCallback(\n",
    "    project=\"GLM4-NER-fintune\",\n",
    "    experiment_name=\"GLM4-9B-Chat\",\n",
    "    description=\"使用智谱GLM4-9B-Chat模型在NER数据集上微调，实现关键实体识别任务。\",\n",
    "    config=dict(\n",
    "        model=model_id,\n",
    "        model_dir=model_dir,\n",
    "        dataset=\"qgyd2021/chinese_ner_sft\",\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=train_dataset,\n",
    "    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),\n",
    "    callbacks=[swanlab_callback],\n",
    ")\n",
    "\n",
    "trainer.train()\n",
    "\n",
    "\n",
    "# ====== Prediction after training ====== #\n",
    "\n",
    "def predict(messages, model, tokenizer):\n",
    "    \"\"\"\n",
    "    Generate the assistant reply for a chat history and print it.\n",
    "\n",
    "    Args:\n",
    "        messages: List of {\"role\": ..., \"content\": ...} dicts passed to the\n",
    "            tokenizer's chat template.\n",
    "        model: Model used for generation.\n",
    "        tokenizer: Tokenizer that provides apply_chat_template and decoding.\n",
    "\n",
    "    Returns:\n",
    "        The decoded reply with special tokens removed.\n",
    "    \"\"\"\n",
    "    # Follow the model's own placement instead of assuming a CUDA device\n",
    "    device = model.device\n",
    "    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "    model_inputs = tokenizer([text], return_tensors=\"pt\").to(device)\n",
    "\n",
    "    # The model may still be in training mode after trainer.train(); generate in\n",
    "    # eval mode so dropout is inactive, then restore the previous mode\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        generated_ids = model.generate(\n",
    "            input_ids=model_inputs.input_ids,\n",
    "            attention_mask=model_inputs.attention_mask,\n",
    "            max_new_tokens=512,\n",
    "        )\n",
    "    finally:\n",
    "        if was_training:\n",
    "            model.train()\n",
    "\n",
    "    # Keep only the newly generated tokens by cutting off the prompt prefix\n",
    "    generated_ids = [\n",
    "        output_ids[len(input_ids) :]\n",
    "        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n",
    "    ]\n",
    "\n",
    "    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
    "    print(response)\n",
    "\n",
    "    return response\n",
    "    \n",
    "\n",
    "# Run the fine-tuned model on the sampled test rows and log each exchange to SwanLab\n",
    "test_text_list = []\n",
    "for index, row in test_df.iterrows():\n",
    "    instruction = row[\"instruction\"]\n",
    "    input_value = row[\"input\"]\n",
    "\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": f\"{instruction}\"},\n",
    "        {\"role\": \"user\", \"content\": f\"{input_value}\"},\n",
    "    ]\n",
    "\n",
    "    response = predict(messages, model, tokenizer)\n",
    "    messages.append({\"role\": \"assistant\", \"content\": f\"{response}\"})\n",
    "    result_text = f\"{messages[0]}\\n\\n{messages[1]}\\n\\n{messages[2]}\"\n",
    "    test_text_list.append(swanlab.Text(result_text, caption=response))\n",
    "\n",
    "swanlab.log({\"Prediction\": test_text_list})\n",
    "swanlab.finish()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
